import torch
import torch.nn as nn

class MyLinear(nn.Module):
    """Fully connected (affine) layer computing ``y = x @ w.T + b``.

    A hand-written counterpart of ``torch.nn.Linear`` that stores its
    parameters under the names ``w`` and ``b``.

    Args:
        inp: Size of the last dimension of the input.
        outp: Size of the last dimension of the output.
        bias: If ``False``, no additive bias is learned and ``self.b`` is
            ``None``. Defaults to ``True``.

    Shape:
        - Input: ``(*, inp)`` -- any number of leading dimensions.
        - Output: ``(*, outp)``.
    """

    def __init__(self, inp, outp, bias=True):
        super().__init__()
        self.in_features = inp
        self.out_features = outp

        # nn.Parameter registers the tensor on the module and defaults to
        # requires_grad=True, so optimizers see and update it.
        self.w = nn.Parameter(torch.empty(outp, inp))
        if bias:
            self.b = nn.Parameter(torch.empty(outp))
        else:
            # Register as None so `self.b` always exists and state_dict
            # handling matches a module that has a bias.
            self.register_parameter("b", None)
        self.reset_parameters()

    def reset_parameters(self):
        """(Re)initialize ``w`` and ``b`` the way ``nn.Linear`` does.

        The weight uses Kaiming-uniform with ``a=sqrt(5)``, which draws from
        U(-1/sqrt(inp), 1/sqrt(inp)). The bias uses the same bound. Scaling by
        fan-in keeps output magnitude roughly independent of ``inp``. A
        zero-centred range avoids pushing every output in the same direction,
        which an unscaled U(0, 1) init would do.
        """
        nn.init.kaiming_uniform_(self.w, a=5 ** 0.5)
        if self.b is not None:
            # Guard inp == 0 to avoid dividing by zero for an empty layer.
            bound = 1 / self.in_features ** 0.5 if self.in_features > 0 else 0.0
            nn.init.uniform_(self.b, -bound, bound)

    def forward(self, x):
        """Apply the affine map to the last dimension of ``x``.

        Equivalent to ``x @ self.w.t() + self.b``. The bias is skipped when
        ``self.b`` is ``None``.
        """
        return nn.functional.linear(x, self.w, self.b)

    def extra_repr(self):
        """Summarize layer sizes in ``repr(module)``."""
        return (
            f"inp={self.in_features}, outp={self.out_features}, "
            f"bias={self.b is not None}"
        )